{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Accompanying code examples of the book \"Introduction to Artificial Neural Networks and Deep Learning: A Practical Guide with Applications in Python\" by [Sebastian Raschka](https://sebastianraschka.com). All code examples are released under the [MIT license](https://github.com/rasbt/deep-learning-book/blob/master/LICENSE). If you find this content useful, please consider supporting the work by buying a [copy of the book](https://leanpub.com/ann-and-deeplearning).*\n",
    "  \n",
    "Other code examples and content are available on [GitHub](https://github.com/rasbt/deep-learning-book). The PDF and ebook versions of the book are available through [Leanpub](https://leanpub.com/ann-and-deeplearning)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.0\n",
      "IPython 6.0.0\n",
      "\n",
      "tensorflow 1.1.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p tensorflow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Zoo -- Saving and Loading Trained Models \n",
    "\n",
    "## from TensorFlow Checkpoint Files and NumPy NPZ Archives"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates different strategies for exporting and importing trained TensorFlow models, based on a simple multilayer perceptron with two hidden layers. These include\n",
    "\n",
    "- Using regular TensorFlow meta and checkpoint files\n",
    "- Loading variables from NumPy archive (.npz) files\n",
    "\n",
    "Note that the graph definition is set up so that it constructs a \"rigid,\" non-trainable TensorFlow classifier if parameters from .npz files are provided. This is on purpose, since it may come in handy in certain use cases, but the code can easily be modified to make the model trainable in that case -- for example, by wrapping the `tf.constant` calls in `fc_layer` in a `tf.Variable` constructor like so:\n",
    "\n",
    "```python\n",
    "...\n",
    "if weight_params is not None:\n",
    "    weights = tf.Variable(tf.constant(weight_params, name='weights',\n",
    "                                      dtype=tf.float32))\n",
    "...\n",
    "```\n",
    "\n",
    "instead of \n",
    "\n",
    "```python\n",
    "...\n",
    "if weight_params is not None:\n",
    "    weights = tf.constant(weight_params, name='weights',\n",
    "                          dtype=tf.float32)\n",
    "...\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Multilayer Perceptron Graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following code cell defines wrapper functions for our convenience; they save us some retyping later when we set up the TensorFlow multilayer perceptron graphs for the trainable and non-trainable models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "##########################\n",
    "### WRAPPER FUNCTIONS\n",
    "##########################\n",
    "\n",
    "\n",
    "def fc_layer(input_tensor, n_output_units, name,\n",
    "             activation_fn=None, seed=None,\n",
    "             weight_params=None, bias_params=None):\n",
    "\n",
    "    with tf.variable_scope(name):\n",
    "\n",
    "        if weight_params is not None:\n",
    "            weights = tf.constant(weight_params, name='weights',\n",
    "                                  dtype=tf.float32)\n",
    "        else:\n",
    "            weights = tf.Variable(tf.truncated_normal(\n",
    "                shape=[input_tensor.get_shape().as_list()[-1], n_output_units],\n",
    "                    mean=0.0,\n",
    "                    stddev=0.1,\n",
    "                    dtype=tf.float32,\n",
    "                    seed=seed),\n",
    "                name='weights',)\n",
    "\n",
    "        if bias_params is not None:\n",
    "            biases = tf.constant(bias_params, name='biases', \n",
    "                                 dtype=tf.float32)\n",
    "\n",
    "        else:\n",
    "            biases = tf.Variable(tf.zeros(shape=[n_output_units]),\n",
    "                                 name='biases', \n",
    "                                 dtype=tf.float32)\n",
    "\n",
    "        act = tf.matmul(input_tensor, weights) + biases\n",
    "\n",
    "        if activation_fn is not None:\n",
    "            act = activation_fn(act)\n",
    "\n",
    "    return act\n",
    "\n",
    "\n",
    "def mlp_graph(n_input=784, n_classes=10, n_hidden_1=128, n_hidden_2=256,\n",
    "              learning_rate=0.1,\n",
    "              fixed_params=None):\n",
    "    \n",
    "    # fixed_params to allow loading weights & biases\n",
    "    # from NumPy npz archives and defining a fixed, non-trainable\n",
    "    # TensorFlow classifier\n",
    "    if not fixed_params:\n",
    "        var_names = ['fc1/weights:0', 'fc1/biases:0',\n",
    "                     'fc2/weights:0', 'fc2/biases:0',\n",
    "                     'logits/weights:0', 'logits/biases:0',]\n",
    "        \n",
    "        fixed_params = {v: None for v in var_names}\n",
    "        found_params = False\n",
    "    else:\n",
    "        found_params = True\n",
    "    \n",
    "    # Input data\n",
    "    tf_x = tf.placeholder(tf.float32, [None, n_input], name='features')\n",
    "    tf_y = tf.placeholder(tf.int32, [None], name='targets')\n",
    "    tf_y_onehot = tf.one_hot(tf_y, depth=n_classes, name='onehot_targets')\n",
    "\n",
    "    # Multilayer perceptron\n",
    "    fc1 = fc_layer(input_tensor=tf_x, \n",
    "                   n_output_units=n_hidden_1, \n",
    "                   name='fc1',\n",
    "                   weight_params=fixed_params['fc1/weights:0'], \n",
    "                   bias_params=fixed_params['fc1/biases:0'],\n",
    "                   activation_fn=tf.nn.relu)\n",
    "\n",
    "    fc2 = fc_layer(input_tensor=fc1, \n",
    "                   n_output_units=n_hidden_2, \n",
    "                   name='fc2',\n",
    "                   weight_params=fixed_params['fc2/weights:0'], \n",
    "                   bias_params=fixed_params['fc2/biases:0'],\n",
    "                   activation_fn=tf.nn.relu)\n",
    "    \n",
    "    logits = fc_layer(input_tensor=fc2, \n",
    "                      n_output_units=n_classes, \n",
    "                      name='logits',\n",
    "                      weight_params=fixed_params['logits/weights:0'], \n",
    "                      bias_params=fixed_params['logits/biases:0'],\n",
    "                      activation_fn=None)\n",
    "    \n",
    "    # Loss and optimizer\n",
    "    ### Only necessary if no existing params are found\n",
    "    ### and a trainable graph has to be initialized\n",
    "    if not found_params:\n",
    "        loss = tf.nn.softmax_cross_entropy_with_logits(\n",
    "            logits=logits, labels=tf_y_onehot)\n",
    "        cost = tf.reduce_mean(loss, name='cost')\n",
    "        optimizer = tf.train.GradientDescentOptimizer(\n",
    "            learning_rate=learning_rate)\n",
    "        train = optimizer.minimize(cost, name='train')\n",
    "\n",
    "    # Prediction\n",
    "    probabilities = tf.nn.softmax(logits, name='probabilities')\n",
    "    labels = tf.cast(tf.argmax(logits, 1), tf.int32, name='labels')\n",
    "    \n",
    "    correct_prediction = tf.equal(labels, \n",
    "                                  tf_y, name='correct_predictions')\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32),\n",
    "                              name='accuracy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and Save Multilayer Perceptron"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./train-images-idx3-ubyte.gz\n",
      "Extracting ./train-labels-idx1-ubyte.gz\n",
      "Extracting ./t10k-images-idx3-ubyte.gz\n",
      "Extracting ./t10k-labels-idx1-ubyte.gz\n",
      "Epoch: 001 | AvgCost: 0.366 | Train/Valid ACC: 0.944/0.948\n",
      "Epoch: 002 | AvgCost: 0.163 | Train/Valid ACC: 0.965/0.963\n",
      "Epoch: 003 | AvgCost: 0.118 | Train/Valid ACC: 0.972/0.969\n",
      "Epoch: 004 | AvgCost: 0.093 | Train/Valid ACC: 0.979/0.974\n",
      "Epoch: 005 | AvgCost: 0.076 | Train/Valid ACC: 0.984/0.977\n",
      "Epoch: 006 | AvgCost: 0.062 | Train/Valid ACC: 0.986/0.974\n",
      "Epoch: 007 | AvgCost: 0.052 | Train/Valid ACC: 0.990/0.977\n",
      "Epoch: 008 | AvgCost: 0.044 | Train/Valid ACC: 0.988/0.975\n",
      "Epoch: 009 | AvgCost: 0.037 | Train/Valid ACC: 0.991/0.978\n",
      "Epoch: 010 | AvgCost: 0.032 | Train/Valid ACC: 0.994/0.979\n",
      "Test ACC: 0.976\n"
     ]
    }
   ],
   "source": [
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Hyperparameters\n",
    "learning_rate = 0.1\n",
    "training_epochs = 10\n",
    "batch_size = 64\n",
    "\n",
    "##########################\n",
    "### GRAPH DEFINITION\n",
    "##########################\n",
    "\n",
    "g = tf.Graph()\n",
    "with g.as_default():\n",
    "    mlp_graph()\n",
    "\n",
    "##########################\n",
    "### DATASET\n",
    "##########################\n",
    "\n",
    "mnist = input_data.read_data_sets(\"./\", one_hot=False)\n",
    "\n",
    "##########################\n",
    "### TRAINING & EVALUATION\n",
    "##########################\n",
    "\n",
    "with tf.Session(graph=g) as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    saver0 = tf.train.Saver()\n",
    "    \n",
    "    for epoch in range(training_epochs):\n",
    "        avg_cost = 0.\n",
    "        total_batch = mnist.train.num_examples // batch_size\n",
    "\n",
    "        for i in range(total_batch):\n",
    "            batch_x, batch_y = mnist.train.next_batch(batch_size)\n",
    "            _, c = sess.run(['train', 'cost:0'], feed_dict={'features:0': batch_x,\n",
    "                                                            'targets:0': batch_y})\n",
    "            avg_cost += c\n",
    "        \n",
    "        train_acc = sess.run('accuracy:0', feed_dict={'features:0': mnist.train.images,\n",
    "                                                      'targets:0': mnist.train.labels})\n",
    "        valid_acc = sess.run('accuracy:0', feed_dict={'features:0': mnist.validation.images,\n",
    "                                                      'targets:0': mnist.validation.labels})  \n",
    "        \n",
    "        print(\"Epoch: %03d | AvgCost: %.3f\" % (epoch + 1, avg_cost / (i + 1)), end=\"\")\n",
    "        print(\" | Train/Valid ACC: %.3f/%.3f\" % (train_acc, valid_acc))\n",
    "        \n",
    "    test_acc = sess.run('accuracy:0', feed_dict={'features:0': mnist.test.images,\n",
    "                                                 'targets:0': mnist.test.labels})\n",
    "    print('Test ACC: %.3f' % test_acc)\n",
    "    \n",
    "    ##########################\n",
    "    ### SAVE TRAINED MODEL\n",
    "    ##########################\n",
    "    saver0.save(sess, save_path='./mlp')"
   ]
  },
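  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the default checkpoint format of recent TensorFlow 1.x releases, `saver0.save(sess, save_path='./mlp')` typically writes several files that share the `mlp` prefix: `mlp.meta` (the graph definition), `mlp.index` and `mlp.data-00000-of-00001` (the variable values), and a small `checkpoint` bookkeeping file. The helper below is a minimal sketch for checking that these files are present before a model is restored. The name `missing_checkpoint_files` and the list of suffixes are assumptions for illustration; the suffixes may differ across TensorFlow versions or for sharded checkpoints.\n",
    "\n",
    "```python\n",
    "from pathlib import Path\n",
    "\n",
    "# Assumed suffixes; they may differ across TensorFlow versions\n",
    "SUFFIXES = ('.meta', '.index', '.data-00000-of-00001')\n",
    "\n",
    "\n",
    "def missing_checkpoint_files(prefix):\n",
    "    # Return the expected files that do not exist for this prefix\n",
    "    return [prefix + s for s in SUFFIXES if not Path(prefix + s).exists()]\n",
    "\n",
    "\n",
    "missing = missing_checkpoint_files('./mlp')\n",
    "if missing:\n",
    "    print('Missing files:', missing)\n",
    "else:\n",
    "    print('All checkpoint files found.')\n",
    "```"
   ]
  },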
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reload Model from Meta and Checkpoint Files"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**You can restart the notebook, and the following code cells should execute without any additional code dependencies.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./train-images-idx3-ubyte.gz\n",
      "Extracting ./train-labels-idx1-ubyte.gz\n",
      "Extracting ./t10k-images-idx3-ubyte.gz\n",
      "Extracting ./t10k-labels-idx1-ubyte.gz\n",
      "INFO:tensorflow:Restoring parameters from ./mlp\n",
      "Test ACC: 0.976\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "mnist = input_data.read_data_sets(\"./\", one_hot=False)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    \n",
    "    saver1 = tf.train.import_meta_graph('./mlp.meta')\n",
    "    saver1.restore(sess, save_path='./mlp')\n",
    "    \n",
    "    test_acc = sess.run('accuracy:0', feed_dict={'features:0': mnist.test.images,\n",
    "                                                 'targets:0': mnist.test.labels})\n",
    "    print('Test ACC: %.3f' % test_acc)"
   ]
  },
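  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The string arguments passed to `sess.run` follow TensorFlow's naming convention: a tensor is addressed as `op_name:output_index` (for example, `accuracy:0`), whereas an operation such as the training step is addressed by its bare name (`train`). Variables created inside `tf.variable_scope(name)` additionally carry the scope as a prefix, as in `fc1/weights:0`. The following pure-Python sketch only illustrates how such names decompose; `split_tensor_name` is a hypothetical helper and not part of the TensorFlow API.\n",
    "\n",
    "```python\n",
    "def split_tensor_name(name):\n",
    "    # Split 'scope/op:index' into (scope, op, index);\n",
    "    # index is None for bare operation names such as 'train'\n",
    "    op_name, _, index = name.partition(':')\n",
    "    scope, _, base = op_name.rpartition('/')\n",
    "    return scope, base, (int(index) if index else None)\n",
    "\n",
    "\n",
    "for n in ['fc1/weights:0', 'accuracy:0', 'train']:\n",
    "    print(n, '->', split_tensor_name(n))\n",
    "```"
   ]
  },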
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Working with NumPy Archive Files and Creating Non-Trainable Graphs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Export Model Parameters to NumPy NPZ files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ./mlp\n",
      "Found variables:\n",
      "fc1/weights:0\n",
      "fc1/biases:0\n",
      "fc2/weights:0\n",
      "fc2/biases:0\n",
      "logits/weights:0\n",
      "logits/biases:0\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "tf.reset_default_graph()\n",
    "with tf.Session() as sess:\n",
    "\n",
    "    saver1 = tf.train.import_meta_graph('./mlp.meta')\n",
    "    saver1.restore(sess, save_path='./mlp')\n",
    "    \n",
    "    var_names = [v.name for v in \n",
    "                 tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]\n",
    "    \n",
    "    params = {}\n",
    "    print('Found variables:')\n",
    "    for v in var_names:\n",
    "        print(v)\n",
    "        \n",
    "        ary = sess.run(v)\n",
    "        params[v] = ary\n",
    "        \n",
    "    np.savez('mlp', **params)"
   ]
  },
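  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`np.savez` stores each keyword argument as a separate array in the archive and appends the `.npz` extension when the file name lacks one, which is why an archive saved as `'mlp'` is loaded as `'mlp.npz'`. The following sketch illustrates the round trip with an in-memory buffer and small random arrays. These arrays are hypothetical stand-ins for trained parameters, not the values exported in this notebook.\n",
    "\n",
    "```python\n",
    "import io\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(123)\n",
    "params = {'fc1/weights:0': rng.randn(4, 3).astype(np.float32),\n",
    "          'fc1/biases:0': np.zeros(3, dtype=np.float32)}\n",
    "\n",
    "buffer = io.BytesIO()       # stands in for a file on disk\n",
    "np.savez(buffer, **params)  # variable names become archive keys\n",
    "buffer.seek(0)\n",
    "\n",
    "loaded = np.load(buffer)\n",
    "for name in loaded.files:\n",
    "    print(name, loaded[name].shape, loaded[name].dtype)\n",
    "    assert np.array_equal(loaded[name], params[name])\n",
    "```"
   ]
  },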
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load NumPy .npz files into the `mlp_graph`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the graph definition was set up so that it constructs a \"rigid,\" non-trainable TensorFlow classifier if parameters from .npz files are provided. This is on purpose, since it may come in handy in certain use cases, but the code can easily be modified to make the model trainable in that case (e.g., by wrapping the `tf.constant` calls in `fc_layer` in a `tf.Variable` constructor)."
   ]
  },
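  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the loaded arrays enter the graph as constants, the resulting classifier is a fixed function of its input. The plain-NumPy sketch below mirrors that computation for a dictionary keyed like the exported variables. `numpy_forward` is a hypothetical helper, the random arrays are stand-ins with reduced layer sizes rather than trained weights, and the activation applied to the output layer is left as an argument that should match whatever `mlp_graph` uses.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def relu(z):\n",
    "    return np.maximum(z, 0.)\n",
    "\n",
    "\n",
    "def numpy_forward(x, params, output_activation=None):\n",
    "    # Two ReLU hidden layers followed by the output layer,\n",
    "    # using the same key names as the exported variables\n",
    "    h1 = relu(x.dot(params['fc1/weights:0']) + params['fc1/biases:0'])\n",
    "    h2 = relu(h1.dot(params['fc2/weights:0']) + params['fc2/biases:0'])\n",
    "    out = h2.dot(params['logits/weights:0']) + params['logits/biases:0']\n",
    "    if output_activation is not None:\n",
    "        out = output_activation(out)\n",
    "    return np.argmax(out, axis=1)\n",
    "\n",
    "\n",
    "# Hypothetical stand-in parameters:\n",
    "# 8 inputs, 5 and 6 hidden units, 3 classes\n",
    "rng = np.random.RandomState(1)\n",
    "sizes = [('fc1', 8, 5), ('fc2', 5, 6), ('logits', 6, 3)]\n",
    "toy_params = {}\n",
    "for name, n_in, n_out in sizes:\n",
    "    toy_params[name + '/weights:0'] = rng.randn(n_in, n_out)\n",
    "    toy_params[name + '/biases:0'] = np.zeros(n_out)\n",
    "\n",
    "labels = numpy_forward(rng.rand(4, 8), toy_params)\n",
    "print(labels.shape)\n",
    "```"
   ]
  },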
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note: Provided that the `fc_layer` and `mlp_graph` wrapper functions from *Define Multilayer Perceptron Graph* have been defined, the following code cell has no other code dependencies.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./train-images-idx3-ubyte.gz\n",
      "Extracting ./train-labels-idx1-ubyte.gz\n",
      "Extracting ./t10k-images-idx3-ubyte.gz\n",
      "Extracting ./t10k-labels-idx1-ubyte.gz\n",
      "Test ACC: 0.976\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "###########################\n",
    "### LOAD DATA AND PARAMS\n",
    "###########################\n",
    "\n",
    "mnist = input_data.read_data_sets(\"./\", one_hot=False)\n",
    "param_dict = np.load('mlp.npz')\n",
    "\n",
    "##########################\n",
    "### GRAPH DEFINITION\n",
    "##########################\n",
    "\n",
    "\n",
    "g = tf.Graph()\n",
    "with g.as_default():\n",
    "    \n",
    "    # here: constructs a non-trainable graph\n",
    "    # due to the provided fixed_params argument\n",
    "    mlp_graph(fixed_params=param_dict)\n",
    "\n",
    "with tf.Session(graph=g) as sess:\n",
    "    \n",
    "    test_acc = sess.run('accuracy:0', feed_dict={'features:0': mnist.test.images,\n",
    "                                                 'targets:0': mnist.test.labels})\n",
    "    print('Test ACC: %.3f' % test_acc)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
